#!/usr/bin/python

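# Line recognizer: computes recognition lattices for text lines and, with
# --extract, extracts character images for cmodel training instead.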

import sys,os,re,glob,math,glob,signal,traceback,codecs
import matplotlib,tables
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
from scipy.ndimage import interpolation
from pylab import *
from optparse import OptionParser
from multiprocessing import Pool
import ocrolib
from ocrolib import number_of_processors,die
from ocrolib.ligatures import lig
from ocrolib import lineseg,morph,linerec,improc,lineproc
import cPickle
from ocrolib.toplevel import *

import warnings,numpy
warnings.simplefilter('ignore',numpy.RankWarning) # disable rank warnings from polyfit

def _exit_on_sigint(*_ignored):
    """Signal handler: leave the program with exit status 1."""
    sys.exit(1)

# install the handler for SIGINT
signal.signal(signal.SIGINT,_exit_on_sigint)

# these options control alignment
import argparse
parser = argparse.ArgumentParser(description = """
Computes recognition lattices for text lines.  Also displays the bestpath
result (recognition result without language model).
""")
parser.add_argument("-x","--exec",dest="execute",help="execute before anything else (usually used for imports)",default="None")
parser.add_argument("-s","--segmenter",help="segmenter",default="lineseg.DPSegmentLine()")
parser.add_argument("-m","--model",help="model file",default=ocrolib.default.model)
parser.add_argument("-w","--whitespace",help="space model file",default=ocrolib.default.space)
parser.add_argument("-Q","--parallel",type=int,default=1,help="number of parallel processes to use")
parser.add_argument("--show",help="show progress",action="store_true")
parser.add_argument('--delay',type=float,default=0.1,help='time to wait after displaying results')
parser.add_argument("--extract",default=None,help="extract characters for cmodel training")
parser.add_argument("--noglob",action="store_true",help="don't perform expansion on the arguments")
parser.add_argument('-q','--quiet',action="store_true",help="don't output progress info")
parser.add_argument('-S','--sizemode',default="perline",help="how to resize characters for extraction")
parser.add_argument('--hdfmode',default="a",help="how the HDF5 file is opened")
parser.add_argument('-e','--exitonerr',action="store_true",help='time to wait after displaying results')

# parser.add_argument("-r","--rejects",help="generate rejects in lattices",action="store_true")
# parser.add_argument("-R","--maxrange",help="max range for grouping",type=int,default=4)
# parser.add_argument("-L","--ligatures",help="output ligatures",action="store_true")
# parser.add_argument("-O","--overwrite",help="overwrite rseg/cseg files",action="store_true")

parser.add_argument("files",default=[],nargs='*',help="input lines")
args = parser.parse_args()
if not args.noglob: args.files = ocrolib.glob_all(args.files)

exec args.execute
cmodel = ocrolib.ocropus_find_file(args.model)

###
### This first option doesn't actually recognize anything, it extracts
### characters for cmodel training.  The reason it is in `ocropus-linerec`
### is because this is the only program that needs to know about normalization.
###

class Hdf5Writer:
    """Appends character patches and their class codes to an HDF5 file.

    Instances work as context managers; leaving the `with` block closes
    the underlying file.
    """

    def __init__(self,fname,mode="w",size=(32,32)):
        """Open `fname` in `mode` and, unless the root group already lists
        a `patches` node, create the extendable `patches` array (one
        `size`-shaped float32 entry per row) and the `classes` array
        (int64)."""
        from tables import Float32Atom,Int64Atom,StringAtom,Filters
        self.size = size
        handle = tables.openFile(fname,mode)
        already_there = "patches" in dir(handle.root)
        if not already_there:
            handle.createEArray(handle.root,'patches',Float32Atom(),shape=(0,)+size,filters=Filters(9))
            handle.createEArray(handle.root,'classes',Int64Atom(),shape=(0,),filters=tables.Filters(9))
            # arrays for file names and bounding boxes are currently disabled:
            #handle.createVLArray(handle.root,'files',StringAtom(120),filters=Filters(9))
            #handle.createEArray(handle.root,'bboxes',Float32Atom(shape=(4,)),shape=(0,),filters=Filters(9))
        self.h5 = handle

    def __enter__(self):
        """Return the writer itself for use in a `with` statement."""
        return self

    def __exit__(self,*args):
        """Close the HDF5 file and drop the reference to it."""
        self.h5.close()
        self.h5 = None

    def insert(self,image,cls,cost=0.0,count=0,fname="",lgeo=None,bbox=(-1,-1,-1,-1)):
        """Append one patch and its class code (`lig.ord(cls)`).

        `image.shape` must equal the writer's `size`.  Only `image` and
        `cls` are stored; the remaining arguments are accepted but unused.
        """
        assert image.shape==self.size,"wrong image shape: %s"%(image.shape,)
        root = self.h5.root
        root.patches.append([image])
        root.classes.append([lig.ord(cls)])
        # file names and bounding boxes are not written at present:
        #root.files.append(fname)
        #root.bboxes.append([array(bbox,'f')])

if args.extract is not None:
    target_xheight = 16
    if args.quiet: print "extracting..."
    try:
        with Hdf5Writer(args.extract,mode=args.hdfmode) as h5:
            for fname in args.files:
                gname = ocrolib.fvariant(fname,"aligned")
                if not os.path.exists(gname): 
                    gname = ocrolib.fvariant(fname,"txt","gt")
                if not os.path.exists(gname): 
                    print fname,"=EXTRACTED=","*** NO ALIGNED TEXT ***",gname
                    continue
                cname = ocrolib.fvariant(fname,"cseg")
                if not os.path.exists(cname): 
                    print fname,"=EXTRACTED=","*** NO CSEG ***",cname
                    continue
                rname = ocrolib.fvariant(fname,"rseg")
                if not os.path.exists(rname): 
                    print fname,"=EXTRACTED=","*** NO RSEG ***",rname
                    continue
                gt = ocrolib.gt_explode(ocrolib.read_text(gname))
                if len(gt)==0: 
                    print fname,"=EXTRACTED=","*** EMPTY GT ***"
                    continue
                if gt[-1]=="\n": gt = gt[:-1]
                cseg = ocrolib.read_line_segmentation(cname)
                rseg = ocrolib.read_line_segmentation(rname)
                csegs = linerec.extract_csegs(cseg)
                maxseg = amax([c.last for c in csegs])
                if maxseg!=len(gt):
                    print fname,"=EXTRACTED=","*** maxseg AND aligned lengths DIFFER***",len(gt),maxseg
                    continue
                csegs = [c.replace(out=[(gt[c.first-1],0.0)]) for i,c in enumerate(csegs)]
                csegs = [c for c in csegs if c.out[0][0]!="~"]
                rsegs = linerec.extract_rsegs(rseg)
                misseg = linerec.extract_non_csegs(rsegs,csegs)
                misseg = [c.replace(out=[("~",0.0)]) for c in misseg]
                if args.show:
                    ion(); gray()
                    figure(1); ocrolib.showgrid([c.img for c in csegs[:100]],xlabels=[c.out[0][0] for c in csegs]) 
                    figure(2); ocrolib.showgrid([m.img for m in misseg[:100]],xlabels=[m.out[0][0] for m in misseg]) 
                    ginput(1,args.delay)
                # TODO optionally double-check against model here
                if not args.quiet: 
                    print fname,"=EXTRACTED=",ocrolib.gt_implode(gt)
                if args.sizemode=="perline":
                    limage = ocrolib.read_image_gray(ocrolib.fvariant(fname,"png"))
                    if len(csegs)<10: 
                        print "too few segments (perline)"
                        continue
                    if limage.shape[1]<100:
                        print "too small (perline)"
                        continue
                    xheight,baseline = lineproc.estimate_xheight(1-limage)
                    if xheight<8 or xheight>40: # TODO make these arguments
                        print "bad xheight:",xheight
                        continue
                for c in csegs+misseg:
                    image = c.img
                    if args.sizemode=="perline":
                        options = dict(bar=(baseline-xheight,baseline))
                        image = improc.line_normalize(image,scale=target_xheight*1.0/xheight,**options)
                    elif args.sizemode=="perchar":
                        image = improc.classifier_normalize(image)
                    cls = c.out[0][0]
                    h5.insert(image=image,cls=cls)
    except list:
        traceback.print_exc()
        sys.exit(1)
    sys.exit(0)

###
### This is the actual line recognizer
###


# cmodel is the model path located right after argument parsing; it is
# rebound here to the loaded component that loutputs() queries.
print "loading",cmodel
cmodel = ocrolib.load_component(cmodel)
# same two-step pattern for the whitespace model: locate the file, then load it
wmodel = ocrolib.ocropus_find_file(args.whitespace)

print "loading",wmodel
wmodel = ocrolib.load_component(wmodel)

# NOTE(review): --segmenter is evaluated as a Python expression, so it can run
# arbitrary code; only pass trusted values.
print "segmenter",args.segmenter
segmenter = eval(args.segmenter)

def invert_image(image):
    """Return the image subtracted from its own maximum value."""
    brightest = amax(image)
    return brightest-image

def loutputs(image,floor=1e-6,keep_rejects=0):
    """Compute the negative log probability (cost) for classifications
    of the character image.

    The image is passed through `improc.classifier_normalize` and then
    to the global `cmodel`.  Each `(cls, p)` pair from `cmodel.coutputs`
    becomes `(cls, -log(max(p, floor)))`; `floor` bounds the cost from
    above.  Classes containing "~" are dropped unless `keep_rejects` is
    true.

    Returns a list of `(class, cost)` pairs.  If normalization fails or
    no output remains, the single reject entry `[("~",30)]` is returned.
    """
    try:
        image = improc.classifier_normalize(image)
    except Exception:
        # FIXME do something more intelligent here, or just have it not fail
        # Only ordinary errors are turned into a reject: a bare `except:`
        # would also swallow the SystemExit raised by the SIGINT handler.
        return [("~",30)]
    outputs = cmodel.coutputs(image)
    outputs = [(cls,-log(max(p,floor))) for cls,p in outputs if "~" not in cls or keep_rejects]
    if len(outputs)<1: outputs = [("~",30)]
    return outputs

def make_connected(rsegs,insert=[("~",30.0)]):
    """Given a list of segmentation hypotheses, inserts
    reject classes for any segments that are not present.
    This reconnects the graph in case one of the previous
    steps has rejected some segment entirely.

    For every index `i` between the smallest `first` and the largest
    `last` found in `rsegs`, a single-segment hypothesis `(i,i)` is
    appended if none exists.  The appended segment has a 1x1 zero image
    and a copy of `insert` as its output.

    `rsegs` is extended in place and also returned; an empty list is
    returned unchanged.
    """
    if len(rsegs)<1: return rsegs
    # a set gives constant-time membership tests for the loop below
    transitions = set((r.first,r.last) for r in rsegs)
    lo = min(r.first for r in rsegs)
    # the covered range ends at the largest `last`; bounding it by `first`
    # would skip a trailing index that only occurs inside a multi-segment group
    hi = max(r.last for r in rsegs)
    for i in range(lo,hi+1):
        if (i,i) not in transitions:
            # copy `insert` so inserted segments do not share the default list
            rsegs.append(linerec.Segment(first=i,last=i,img=zeros((1,1)),
                                         bbox=(slice(0,1),slice(0,1)),
                                         out=list(insert),sp=array([1.0,0])))
    return rsegs

# with --show, turn on interactive plotting and select the gray colormap
if args.show:
    ion()
    gray()

def process1(fname):
    try:
        if not args.quiet: print fname,"=RAW=",

        # read the image and display it
        
        image = ocrolib.read_image_gray(fname)

        if args.show:
            figure(2); clf()
            figure(1); clf(); subplot(311)
            imshow(image)
            ginput(1,0.1)

        # generate character candidates

        rseg = segmenter.charseg(image)
        rseg = morph.renumber_by_xcenter(rseg)
        rsegs = linerec.extract_rsegs(rseg)
        if len(rsegs)<1:
            if args.quiet: print fname,"=RAW=",
            print "NO RAW SEGMENTS"
            return

        if args.show:
            figure(1); subplot(312)
            morph.showlabels(rseg)
            ginput(1,0.1)

        # classify each character

        recognized = [r.replace(out=loutputs(r.img)) for r in rsegs]

        # compute whitespace probabilities
        
        wmodel.setLine(invert_image(image))
        recognized = [r.replace(sp=wmodel.classifySpace(r.bbox[1].stop)) for r in recognized]

        if args.show:
            figure(2)
            labels = [(r.out[0][0] if r.out and r.out[0][1]<1 else "_") for r in recognized]
            ocrolib.showgrid([r.img for r in recognized][:60],cols=20,xlabels=labels)
            
        # make sure the resulting graph is connected

        recognized = make_connected(recognized)

        # output the best path without a language model for debugging

        labels,costs,trans = linerec.bestpath(recognized,noreject=0)

        if labels is None:
            if args.quiet: print fname,"=RAW=",
            print "FAILED (no bestpath)"
            return

        if not args.quiet: print "".join(labels)

        # write the lattice and the raw segmentation
        
        with open(ocrolib.fvariant(fname,"lattice"),"w") as stream:
            linerec.write_lattice(stream,recognized)

        ocrolib.write_line_segmentation(ocrolib.fvariant(fname,"rseg"),rseg)
        
        if args.show:
            ginput(1,args.delay)
    except ocrolib.RecognitionError,e:
        print fname,":",e
        return
    except:
        print "ERROR IN",fname
        traceback.print_exc()
        if args.exitonerr: sys.exit(1)
        return

###
### top level: either run sequentially or in parallel under multiprocessing
###

print "recognizing",len(args.files),"files"
if args.show:
    args.parallel = 1
if args.parallel==1:
    for fname in args.files:
        process1(fname)
else:
    pool = Pool(processes=args.parallel)
    result = []
    for r in pool.imap_unordered(process1,args.files):
        result.append(r)
        if len(result)%100==0: print "==========",len(result),len(args),"=========="
